function mmalignBCtogene(bcidx,geneidx,bcfname,genefname,uselocal)
%MMALIGNBCTOGENE Align barcode sequencing images to gene images.
%   mmalignBCtogene(bcidx,geneidx) estimates a translation between image
%   number bcidx of the first barcode TIFF and image number geneidx of the
%   first gene TIFF found in the current folder, then applies that
%   translation to every image of every barcode TIFF.
%
%   Inputs:
%     bcidx     - index of the image (channel) inside the barcode TIFF used
%                 for registration.
%     geneidx   - index of the image (channel) inside the gene TIFF used
%                 for registration.
%     bcfname   - (optional, default 'bc') substring identifying barcode
%                 files; files are matched with *<bcfname>*.tif.
%     genefname - (optional, default 'gene') substring identifying gene
%                 files; files are matched with *<genefname>*.tif.
%     uselocal  - (optional, default 0) 0: phase correlation (imregcorr) on
%                 the images with a 200-pixel border removed;
%                 1: tile-based cross-correlation (align_local).
%
%   Files written to the current folder:
%     comp.tif          - two images: the normalized gene reference and the
%                         normalized, aligned barcode reference.
%     reg<barcode name> - each barcode TIFF with all images warped by the
%                         estimated transform (uint16).
%
%   Errors are raised when no matching files exist or when uselocal is not
%   0 or 1.

if nargin<3
    bcfname='bc';
end
if nargin<4
    genefname='gene';
end
if nargin<5
    uselocal=0;
end

bcfiles=dir(['*',bcfname,'*.tif']);
genefiles=dir(['*',genefname,'*.tif']);

% Fail early with a readable message instead of an index error below.
if isempty(bcfiles)
    error('mmalignBCtogene:noBarcodeFiles', ...
        'No barcode files matching *%s*.tif were found in the current folder.',bcfname);
end
if isempty(genefiles)
    error('mmalignBCtogene:noGeneFiles', ...
        'No gene files matching *%s*.tif were found in the current folder.',genefname);
end

bcfiles=sort_nat({bcfiles.name});
genefiles=sort_nat({genefiles.name});

% Read the reference images from the first cycle of each set.
geneim=imread(genefiles{1},geneidx);
bcim=imread(bcfiles{1},bcidx);

% Background subtraction (top-hat, disk radius 50) before registration.
geneim=imtophat(geneim,strel('disk',50));
bcim=imtophat(bcim,strel('disk',50));

% Estimate the transform that maps the barcode image onto the gene image.
if uselocal==0
    tform=imregcorr(bcim(200:end-200,200:end-200), geneim(200:end-200,200:end-200), 'translation');
elseif uselocal==1
    tform=align_local(geneim,bcim);
else
    % Previously tform was left undefined here and imwarp failed later.
    error('mmalignBCtogene:invalidUselocal','uselocal must be 0 or 1.');
end

% Output grid matches the gene image (2-D size, even if it had a 3rd dim).
Rfixed=imref2d(size(sum(geneim,3)));
alignedbcim=uint16(imwarp(bcim,tform,'OutputView',Rfixed));

% Save a two-image comparison stack, each image scaled by its own maximum.
imwrite(double(geneim)./max(double(geneim(:))),'comp.tif');
imwrite(double(alignedbcim)./max(double(alignedbcim(:))),'comp.tif','WriteMode','Append');

% Apply the transform to every image of every barcode file.
for i=1:length(bcfiles)
    bcinfo=imfinfo(bcfiles{i});
    npages=numel(bcinfo); % one struct element per image, any orientation
    im=zeros(bcinfo(1).Height, bcinfo(1).Width,npages);
    for n=1:npages
        im(:,:,n)=imread(bcfiles{i},n);
    end
    alignedim=uint16(imwarp(im,tform,'OutputView',Rfixed));

    % First image creates/overwrites the file, the rest are appended.
    outname=['reg',bcfiles{i}];
    imwrite(alignedim(:,:,1),outname);
    for n=2:size(alignedim,3)
        imwrite(alignedim(:,:,n),outname,'WriteMode','Append');
    end
end
end

function tform=align_local(templatesum,imagesum)
%ALIGN_LOCAL Estimate a translation from tile-wise cross-correlation.
%   tform=align_local(templatesum,imagesum) splits both images into
%   block_size x block_size tiles, cross-correlates the brightest tiles of
%   templatesum with the matching tiles of imagesum (after upsampling by
%   resize_factor), averages the correlation maps and converts the peak
%   position into a pure translation.
%
%   Inputs:
%     templatesum - reference image.
%     imagesum    - image to be aligned; indexed with the same row/column
%                   ranges as templatesum.
%   Output:
%     tform       - affine2d translation [1 0 0; 0 1 0; xoff yoff 1].
%
%   Errors are raised when templatesum is smaller than one tile in either
%   dimension or when the averaged correlation map holds only NaN.

resize_factor=5;           % upsampling applied to each tile before correlation
intensity_max_thresh=3000; % threshold forwarded to my_xcorr2_avoidsoma_fft
subsample_rate=6;          % roughly 1/subsample_rate of the tiles are used

block_size=256; % block_size/2 is the max shift we can deal with when using "same" in xcorr2, and block_size-1 when using "full". "Same" should be more accurate since it's less affected by values near the edge, and faster.
block_num_1=floor(size(templatesum,1)/block_size);
block_num_2=floor(size(templatesum,2)/block_size);

% Without at least one full tile there is nothing to correlate.
if block_num_1<1 || block_num_2<1
    error('align_local:imageTooSmall', ...
        'Images must be at least %d x %d pixels for tile-based alignment.', ...
        block_size,block_size);
end

% Crop both images to a whole number of tiles.
if block_num_1*block_size~=size(templatesum,1)
    templatesum=templatesum(1:block_num_1*block_size,:);
    imagesum=imagesum(1:block_num_1*block_size,:);
end
if block_num_2*block_size~=size(templatesum,2)
    templatesum=templatesum(:,1:block_num_2*block_size);
    imagesum=imagesum(:,1:block_num_2*block_size);
end

im1split=mat2cell(templatesum,block_size*ones(1,block_num_1),block_size*ones(1,block_num_2));
im2split=mat2cell(imagesum,block_size*ones(1,block_num_1),block_size*ones(1,block_num_2));

% Rank tiles by total template intensity and keep the top fraction given by
% subsample_rate. Sensitive to bright spots.
im1split_sum=cellfun(@(x) sum(x,'all'),im1split);
[~,idx]=sort(im1split_sum(:),'descend');

% Always use at least one tile; with fewer than 3 tiles the rounded count
% would be 0 and the averaged correlation would be all NaN.
num_used=max(1,round(numel(im1split)/subsample_rate));

% One correlation map per selected tile ("same" size, upsampled).
c=zeros(size(im1split{1},1)*resize_factor,size(im1split{1},2)*resize_factor);
c1=repmat(c,1,1,num_used);
for i=1:num_used
    % Tiles that are entirely zero keep an all-zero map.
    if max(im1split{idx(i)},[],'all')>0
        c1(:,:,i)=my_xcorr2_avoidsoma_fft(imresize(im1split{idx(i)},resize_factor), ...
            imresize(im2split{idx(i)},resize_factor), ...
            intensity_max_thresh);
    end
end
c=mean(c1,3,'omitnan');

[ypeak,xpeak] = find(c==max(c(:)));
if isempty(ypeak)
    % max() of an all-NaN map is NaN, which never compares equal.
    error('align_local:noPeak', ...
        'Cross-correlation map contains only NaN values; cannot estimate a shift.');
end

% Peak position relative to the map center, converted back to original
% pixels (valid for "same"-sized correlation).
yoffSet = (ypeak-size(im1split{1},1)*resize_factor/2)/resize_factor;
xoffSet = (xpeak-size(im1split{1},2)*resize_factor/2)/resize_factor;

% If several positions tie for the maximum, take the smallest shift.
[~,I]=min(abs(xoffSet)+abs(yoffSet));

tform=affine2d([1,0,0;0,1,0;xoffSet(I) yoffSet(I),1]);
end



function c=my_xcorr2(a,b)
%MY_XCORR2 Full 2-D cross-correlation of a with b.
%   Implemented as a convolution of a with b reversed along both of its
%   first two dimensions (equivalent to a 180-degree rotation).
b_reversed=flip(flip(b,1),2);
c=conv2(a,b_reversed);
end

function c=my_xcorr2_avoidsoma_fft(a,b,thresh,radius)
%MY_XCORR2_AVOIDSOMA_FFT "Same"-sized 2-D cross-correlation via conv_fft2.
%   c=my_xcorr2_avoidsoma_fft(a,b,thresh,radius) cross-correlates a with b
%   using conv_fft2 (MATLAB File Exchange) for speed.
%
%   thresh > 0 : both inputs are Gaussian-smoothed with sigma = radius
%                (low-frequency information, intended for DAPI), and the
%                correlation is scaled by the ratio of two mask
%                correlations: pixels <= thresh in both inputs over all
%                overlapping pixels (assuming conv_fft2 behaves like conv2).
%   otherwise  : a morphological opening with a disk of the given radius is
%                subtracted from each input before correlating
%                (high-frequency features only).
%
%   radius defaults to 3. Note that it has to account for the scaling
%   factor applied by the caller. This may cause issues when only
%   background signal is available for registration.

if nargin<4
    radius=3;
end

if thresh>0
    % Masks are taken from the inputs before any smoothing.
    below_a=a<=thresh;
    below_b=b<=thresh;
    valid_pairs=conv_fft2(below_a,rot90(below_b,2),'same');
    all_pairs=conv_fft2(ones(size(a)),ones(size(b)),'same');

    % For DAPI registration, use low-frequency information.
    smooth_a=imgaussfilt(a,radius);
    smooth_b=imgaussfilt(b,radius);

    raw_corr=conv_fft2(smooth_a,rot90(smooth_b,2),'same');
    c=raw_corr.*valid_pairs./all_pairs;
else
    % Remove structures larger than the disk, keeping fine features.
    se=strel('disk',radius);
    detail_a=a-imopen(a,se);
    detail_b=b-imopen(b,se);
    c=conv_fft2(detail_a,rot90(detail_b,2),'same');
end

end